Pokemon classification with a Support Vector Machine
BSHT Michielsen MSc
This notebook demonstrates how to use a Support Vector Machine (SVM) for image classification. Image recognition is the ability of a computer to identify an object in an image based on the visual characteristics of that object. This is a classification problem: each possible object is a class, and a given image should be assigned to exactly one class with as high a certainty as feasible. Training such a classification model requires a large number of images of each object. Next to this notebook there should be a folder named data containing several Pokemon images. These images are a subset of the Pokemon collection by Lance Zhang, chosen because the selected Pokemon have strikingly different colors, so the machine can hopefully distinguish them fairly well. More images of the same Pokemon, or even of different Pokemon, can be downloaded and added to the data folder.
First, the versions of the required libraries are printed. It is always wise to report the versions of the libraries used so that, should problems arise in the future, one can restore a state in which the notebook worked.
import copy, pathlib, math
import PIL.Image as Image
import sklearn
import numpy
import matplotlib
import matplotlib.pyplot as plt
print("scikit-learn version:", sklearn.__version__) # 1.4.1.post1
print("numpy version:", numpy.__version__) # 1.26.4
print("matplotlib version:", matplotlib.__version__) # 3.8.3
scikit-learn version: 1.4.1.post1
numpy version: 1.26.4
matplotlib version: 3.8.3
📦 Data provisioning
In real life the data provisioning phase would likely include more steps around data sourcing and data quality; for demo purposes, however, this notebook restricts it to merely loading the images from the data folder, without any concern for quantity or quality.
The code below loads the images and treats the subfolder names as the class labels. All images must be the same size (and, in this case, square as well), so the code resizes them automatically. If high-resolution images are available, the size parameter can be increased, which will probably improve performance slightly at a significantly increased training time. The chosen size of 256 is a middle ground that should give fair results at a reasonable training time.
size = 256
def load_image(file, size):
    # Force RGB so grayscale or RGBA files end up with a consistent 3 channels.
    img = Image.open(file).convert('RGB')
    img = img.resize((size, size))
    return numpy.array(img).flatten()
def load_labelled_images(path, size):
    labels = list()
    files = list()
    for file_info in pathlib.Path(path).glob("**/*.jpg"):
        # The name of the folder containing the file is the class label.
        labels.append(file_info.parent.name)
        files.append(str(file_info))
    imgs = numpy.array([load_image(f, size) for f in files])
    return imgs, numpy.array(labels)
images, labels = load_labelled_images("./CatsAndDogsData", size)
# images, labels = load_labelled_images("./PokemonData", size)
# images, labels = load_labelled_images("./data", size)
print("Loaded", len(images), "images in the following", len(numpy.unique(labels)), "classes:")
for label in numpy.unique(labels):
print(label)
Loaded 636 images in the following 2 classes:
cats
dogs
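Besides listing the class names, it can be useful to check how many images each class contributes, since a strong imbalance would bias the classifier toward the majority class. A minimal sketch, using a hypothetical stand-in for the `labels` array loaded above:

```python
import numpy

# Hypothetical stand-in for the `labels` array produced by load_labelled_images.
labels = numpy.array(["cats"] * 330 + ["dogs"] * 306)

# return_counts=True gives the number of occurrences per unique label.
classes, counts = numpy.unique(labels, return_counts=True)
for cls, count in zip(classes, counts):
    print(cls, count)
```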
📃 Sample the data
To get an impression of the data, a sample of the loaded images is plotted here to check that they were loaded correctly. The parameter sample_size can be increased to show more images.
sample_size = 24
# Shuffle a copy so the pairing of images and labels stays intact.
plotimgs = copy.deepcopy(images)
numpy.random.shuffle(plotimgs)
rows = plotimgs[:sample_size]
_, subplots = plt.subplots(nrows=math.ceil(sample_size/8), ncols=8, figsize=(18, int(sample_size/3)))
subplots = subplots.flatten()
for i, x in enumerate(rows):
    subplots[i].imshow(numpy.reshape(x, [size, size, 3]))
    subplots[i].set_xticks([])
    subplots[i].set_yticks([])
🛠️ Preprocessing
Because this case uses raw pixels as features, there is no meaningful feature selection: one cannot decide beforehand that some pixels are better indicators than others. Therefore, there is little to do in terms of preprocessing other than splitting the dataset into a training set and a test set.
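One optional preprocessing step worth mentioning: SVMs are sensitive to feature scale, so mapping the 0–255 pixel values to the 0–1 range is a cheap transformation that often speeds up convergence without changing the information content. A minimal sketch, using a small random array as a stand-in for the flattened `images` array:

```python
import numpy

# Hypothetical stand-in for a batch of 4 flattened 256x256 RGB images.
images = numpy.random.randint(0, 256, size=(4, 256 * 256 * 3), dtype=numpy.uint8)

# Scale the uint8 pixel values into the [0, 1] range as floats.
images_scaled = images.astype(numpy.float64) / 255.0
print(images_scaled.min(), images_scaled.max())
```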
🪓 Splitting into train/test
A 70%/30% split is chosen here in order to have a fairly large number of test images.
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(images, labels, test_size=.3, random_state=0)
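When the classes are not perfectly balanced, passing `stratify=labels` to `train_test_split` keeps the class proportions identical in both splits. A minimal sketch with a small hypothetical label set standing in for the notebook's data:

```python
import numpy
from sklearn.model_selection import train_test_split

# Hypothetical tiny dataset standing in for the notebook's images/labels.
images = numpy.arange(20).reshape(10, 2)
labels = numpy.array(["cats"] * 5 + ["dogs"] * 5)

# stratify=labels preserves the cats/dogs ratio in train and test alike.
X_train, X_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.3, random_state=0, stratify=labels)
print(numpy.unique(y_test, return_counts=True))
```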
🧬 Modelling
In this step the model is fitted on the training set only; in this case, a Support Vector Machine for classification. Several kernels are tried to see which one performs best.
from sklearn.svm import SVC
kernels = ['linear', 'poly', 'sigmoid', 'rbf']
for kernel in kernels:
    model = SVC(kernel=kernel)
    model.fit(X_train, y_train)
    score = model.score(X_test, y_test)
    print(f'Accuracy with {kernel} kernel: {score}')
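Note that after the loop, `model` is simply the last kernel fitted (rbf), which is what the evaluation below operates on. If the best-scoring kernel should be kept instead, a reference to it has to be stored explicitly. A minimal sketch using synthetic two-blob data in place of the image arrays:

```python
import numpy
from sklearn.svm import SVC

# Synthetic stand-in data: two well-separated Gaussian blobs in 5 dimensions.
rng = numpy.random.default_rng(0)
X = numpy.vstack([rng.normal(0, 1, (40, 5)), rng.normal(3, 1, (40, 5))])
y = numpy.array([0] * 40 + [1] * 40)
X_train, y_train = X[::2], y[::2]
X_test, y_test = X[1::2], y[1::2]

# Keep the best-scoring model instead of silently evaluating the last one.
best_score, best_model = -1.0, None
for kernel in ['linear', 'poly', 'sigmoid', 'rbf']:
    candidate = SVC(kernel=kernel).fit(X_train, y_train)
    score = candidate.score(X_test, y_test)
    if score > best_score:
        best_score, best_model = score, candidate
print(best_model.kernel, best_score)
```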
🔬 Evaluation
Below, a classification report is printed, showing how well the model performed for each of the classes. Note that `model` still refers to the last kernel fitted in the loop above.
from sklearn.metrics import classification_report
predictions = model.predict(X_test)
report = classification_report(y_test, predictions)
print(report)
              precision    recall  f1-score   support

        cats       0.57      0.57      0.57        94
        dogs       0.58      0.58      0.58        97

    accuracy                           0.58       191
   macro avg       0.58      0.58      0.58       191
weighted avg       0.58      0.58      0.58       191
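To see where the errors go, a confusion matrix complements the classification report: rows are true classes, columns are predicted classes, and the diagonal counts correct predictions. A minimal sketch with hypothetical arrays standing in for the notebook's `y_test` and `predictions`:

```python
import numpy
from sklearn.metrics import confusion_matrix

# Hypothetical stand-ins for the notebook's y_test and predictions arrays.
y_test = numpy.array(["cats", "cats", "dogs", "dogs", "dogs"])
predictions = numpy.array(["cats", "dogs", "dogs", "dogs", "cats"])

# labels fixes the row/column order: row = true class, column = predicted class.
cm = confusion_matrix(y_test, predictions, labels=["cats", "dogs"])
print(cm)
```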
With an accuracy of about 58%, the model performs only slightly better than chance on this two-class problem; the classes are evidently hard to separate from raw pixels alone. The code below plots every image in the test set, including the predicted label as well as whether this prediction was correct or wrong.
_, subplots = plt.subplots(nrows=math.ceil(len(X_test)/4), ncols=4, figsize=(15, len(X_test)))
subplots = subplots.flatten()
for i, x in enumerate(X_test):
    subplots[i].imshow(numpy.reshape(x, [size, size, 3]))
    subplots[i].set_xticks([])
    subplots[i].set_yticks([])
    subplots[i].set_title(predictions[i] + (" (correct)" if predictions[i] == y_test[i] else " (wrong)"))